# ------------------------------------------------------------------------------
# Creates shift variables and calculates leave-one-out shift test rates
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 01_leaveout_shift_test_rates.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings

# Directory (under the project root) that receives this script's .rds outputs.
temp <- here::here("code", "05_natural_experiment", "temp")
# Not referenced by name anywhere else in this script.
# NOTE(review): presumably fills an `{overnight_lab}` placeholder when glue()
# expands the cohort path below -- confirm against lib/filepaths.yml before
# removing or renaming it.
overnight_lab <- ""
# Project utility module; supplies `u$safe_left_join()` used further down.
u <- modules::use(here::here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
# File paths for the project, read from YAML.
paths <- read_yaml(here::here("lib", "filepaths.yml"))
# glue() substitutes any `{...}` expressions in the stored path using the
# variables in scope here, then the cohort is read from the resulting file.
cohort <- readRDS(glue(paths$analysis$full_cohort))

# Shift Windows ----------------------------------------------------------------
# Assigns every visit with a known start time to the 12-hour and the 8-hour
# shift it falls in. A shift is identified by its start timestamp (one minute
# past the boundary hour, UTC), stored as a factor of epoch seconds.
message("Getting shift windows...")
tdelta <- as.difftime(1, units = "days")

# Timestamp `start_hour`:01:00 UTC on `day`. Built with date-time arithmetic so
# that no per-row string formatting and re-parsing is needed; the resulting
# epoch seconds match what parsing "<day> <start_hour>:01:00" would give.
shift_start <- function(day, start_hour) {
  lubridate::as_datetime(day) +
    lubridate::dhours(start_hour) +
    lubridate::dminutes(1)
}

DT <- cohort %>%
  filter(!is.na(start_datetime)) %>%
  mutate(day = lubridate::date(start_datetime)) %>%
  mutate(
    # 12-hour shifts start at 08:01 and 20:01. Visits before 08:00 belong to
    # the shift that started the previous evening.
    shift_12 = case_when(
      hour >= 8 & hour < 20 ~ shift_start(day, 8),
      hour >= 20 ~ shift_start(day, 20),
      hour < 8 ~ shift_start(day - tdelta, 20)
    ),
    # 8-hour shifts start at 00:01, 08:01 and 16:01. The last branch is an
    # explicit condition rather than a catch-all so that a missing `hour`
    # yields NA (as it does for shift_12) instead of landing in the 16:01
    # shift.
    shift_08 = case_when(
      hour < 8 ~ shift_start(day, 0),
      hour < 16 ~ shift_start(day, 8),
      hour >= 16 ~ shift_start(day, 16)
    )
  ) %>%
  mutate(
    # Shift identifiers become factors keyed by epoch seconds.
    shift_12 = factor(as.numeric(shift_12)),
    shift_08 = factor(as.numeric(shift_08)),
    # Converted only after the numeric comparisons above have been made.
    hour = as.factor(hour)
  )

# Save pre-excl shifts ---------------------------------------------------------
message("Saving pre-exclusion shifts data.frame...")
# Create the output directory on first use so the write cannot fail (e.g. on a
# fresh checkout) after the shift variables have already been computed.
if (!dir.exists(temp)) {
  dir.create(temp, recursive = TRUE)
}
write_rds(DT, file.path(temp, "shifts_df_unexcl.rds"))

# Leave-one-out Test Rates -----------------------------------------------------
# For every non-excluded visit, computes the test rate among the *other* visits
# in the same 12-hour shift.
message("Calculating leave-one-out test rates for each shift...")
DT <- filter(DT, !exclude)

# Per-shift visit count and test rate.
# NOTE(review): mean() is called without na.rm, so a single missing
# test_010_day makes the whole shift's rate NA -- confirm the column is
# complete.
shift_12_test_rates <- DT %>%
  group_by(shift_12) %>%
  summarize(
    shift_12_n = n(),
    shift_12_trate = mean(test_010_day)
  ) %>%
  ungroup()

# Attach the shift-level count and rate to every visit in that shift.
DT <- DT %>%
  u$safe_left_join(shift_12_test_rates)

# Leave-out calculations
DT <- DT %>%
  mutate(
    # n * rate recovers the shift's test total; subtracting the visit's own
    # outcome and dividing by the remaining n - 1 visits gives the rate among
    # its shift-mates. A visit that is alone in its shift has no shift-mates
    # (the formula is 0 / 0), so it is set to NA explicitly instead of NaN.
    shift_12_trate_leaveout = if_else(
      shift_12_n > 1,
      ((shift_12_n * shift_12_trate) - test_010_day) / (shift_12_n - 1),
      NA_real_
    ),
    # hour is already a factor; re-factoring drops levels that no longer occur
    # after the exclusion filter.
    hour = factor(hour)
  )

# Save -------------------------------------------------------------------------
# Persist the visit-level table, now carrying the leave-one-out shift rates.
message("Saving shifts data.frame...")
shift_rates_file <- file.path(temp, "shift_test_rates.rds")
write_rds(DT, shift_rates_file)

message("Done.")
